import numpy,os
from scipy.optimize import curve_fit
from scipy import constants
from matplotlib import pyplot

pi = numpy.pi
# Angular frequencies (rad/s) computed as 2*pi*f; the f values are presumably
# trap frequencies in Hz along x, y, z -- confirm against the experiment.
w_x = 2*pi*181
w_y = 2*pi*44
w_z = 2*pi*178
# Physical constants in SI units (scipy.constants).
c = constants.c
h = constants.h
hbar = constants.hbar
u = constants.u  # atomic mass constant, kg
kB = constants.Boltzmann

# Directory containing this script; the CSV data files are read from here.
cwd = os.path.dirname(os.path.abspath(__file__))

# Mass numbers (in units of u), presumably of 87Rb and 133Cs.
mRb = 87.
mCs = 133.

# Molecular mass in kg: (mRb + mCs) atomic mass units.
m = (mRb + mCs)*u

# NOTE(review): presumably a temperature in kelvin; unused in this file -- confirm.
T = 1.5e-6

###### Functions #####################
def AverageData(data):
    '''
        Average experimental data that has repeats in the independent (x)
        variable, returning the mean and standard error of the dependent (y)
        variable for each unique x.

        Standard error is given by std_dev/sqrt(N-1), where std_dev is the
        population standard deviation (numpy.std, ddof=0) and N is the number
        of repeats. This equals the sample standard deviation divided by
        sqrt(N). An x value with only one repeat has no defined standard error;
        NaN is reported for it (without a divide-by-zero warning).

        Input:
        data -> 2D array; column 0 is x, column 1 is y (other columns ignored)

        Output:
        out = numpy.ndarray((Nx,3)) where Nx is number of unique x values

        out[:,0] -> unique values of independent variable (x), ascending
        out[:,1] -> mean values of dependent variable (y)
        out[:,2] -> standard error of dependent variable (y)
    '''
    # Sort rows by x so that repeats of each x value are contiguous.
    order = numpy.argsort(data[:, 0])
    data = data[order, :]
    # Unique x values (ascending, matching the sort above) and their counts.
    x_values, counts = numpy.unique(data[:, 0], return_counts=True)

    output = numpy.zeros((len(counts), 3))
    start = 0  # first row of the current group in the sorted data
    for i, n in enumerate(counts):
        ys = data[start:start + n, 1]
        if n > 1:
            sem = numpy.std(ys)/numpy.sqrt(n - 1)
        else:
            # Single measurement: spread is undefined.
            sem = numpy.nan
        output[i, :] = [x_values[i], numpy.average(ys), sem]
        start += n
    return output


def _fit_and_plot(data_ax, summary_ax, avg_data, y0_guess):
    '''
        Plot averaged decay data, fit the linear model fitfn to the points
        with time <= tmax, and draw the result.

        data_ax    -> axes for the data, the fitted subset and the fit line
        summary_ax -> axes that receives one (y0, A) point with error bars
        avg_data   -> output of AverageData (columns: time, mean, std. error)
        y0_guess   -> initial guess for the intercept y0 passed to curve_fit

        Prints the fitted slope A divided by the mean of the earliest fitted
        point, followed by its propagated uncertainty.

        Returns (curve, err): best-fit [A, y0] and their 1-sigma errors.
    '''
    # Full data set.
    data_ax.errorbar(avg_data[:, 0], avg_data[:, 1], yerr=avg_data[:, 2],
                     fmt='o')

    # Subset used for the fit, overplotted. AverageData sorts by time, so
    # fitdata[0] is the earliest point.
    fitdata = avg_data[avg_data[:, 0] <= tmax]
    data_ax.errorbar(fitdata[:, 0], fitdata[:, 1], yerr=fitdata[:, 2],
                     fmt='o')

    curve, cov = curve_fit(fitfn, fitdata[:, 0], fitdata[:, 1],
                           sigma=fitdata[:, 2], absolute_sigma=True,
                           p0=[0.4e11/20, y0_guess])
    err = numpy.sqrt(numpy.diag(cov))

    # Relative loss rate A/N(t_first). The relative errors are added in
    # quadrature, which treats them as independent. (An alternative is to
    # normalise by the fitted y0, curve[1], instead.)
    rel_rate = curve[0]/fitdata[0, 1]
    print(rel_rate, rel_rate*numpy.sqrt((err[0]/curve[0])**2
                                         + (fitdata[0, 2]/fitdata[0, 1])**2))

    summary_ax.errorbar(curve[1], curve[0], xerr=err[1], yerr=err[0])

    data_ax.plot(times, fitfn(times, *curve))
    data_ax.set_ylabel("Molecule Number")
    data_ax.set_xlabel("Time (ms)")
    return curve, err


# os.path.join keeps the paths portable (the original hard-coded "\\").
data_max = numpy.genfromtxt(os.path.join(cwd, "Max Density.csv"), delimiter=',')
avg_data_max = AverageData(data_max)

data_half = numpy.genfromtxt(os.path.join(cwd, "Half Density.csv"), delimiter=',')
avg_data_half = AverageData(data_half)


tmax = 15  # ms; only points with t <= tmax are fitted
times = numpy.linspace(0, 50, 100)  # ms; grid used to draw the fit lines


def fitfn(x, A, y0):
    '''Linear decay model y(t) = y0 - A*t.'''
    return y0 - x*A


fig = pyplot.figure()   # two panels: data and fits
fig2 = pyplot.figure()  # fitted A against fitted y0 for both data sets

ax = fig2.add_subplot(111)

axmax = fig.add_subplot(121)
axmax.set_title("Maximum number")
curve_max, err_max = _fit_and_plot(axmax, ax, avg_data_max, 1e11)

axhalf = fig.add_subplot(122, sharey=axmax)
axhalf.set_title("Half number")
curve_half, err_half = _fit_and_plot(axhalf, ax, avg_data_half, 0.5e11)

ax.set_yscale("log")
ax.set_xscale("log")
axmax.set_ylim(0, 2500)  # also applies to axhalf via sharey
axhalf.set_xlim(0, 65)
axmax.set_xlim(0, 65)
fig.tight_layout()
pyplot.show()
